! ALIGN macro: rounds the integer expression I up to the next multiple of A
! by adding A-1 and relying on truncating integer division. Both arguments are
! pasted textually and expanded more than once (A three times), so pass only
! side-effect-free expressions.
! NOTE(review): the formula assumes non-negative I and positive A - confirm at call sites.
#define ALIGN(I, A) (((I)+(A)-1)/(A)*(A))
! (C) Copyright 2000- ECMWF.
! (C) Copyright 2000- Meteo-France.
! (C) Copyright 2022- NVIDIA.
! 
! This software is licensed under the terms of the Apache Licence Version 2.0
! which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
! In applying this licence, ECMWF does not waive the privileges and immunities
! granted to it by virtue of its status as an intergovernmental organisation
! nor does it submit to any jurisdiction.
!

MODULE TRLTOM_PACK_UNPACK
  USE BUFFERED_ALLOCATOR_MOD, ONLY: ALLOCATION_RESERVATION_HANDLE
  USE PARKIND_ECTRANS,        ONLY: JPIM
  IMPLICIT NONE

  PRIVATE
  PUBLIC :: TRLTOM_PACK_HANDLE, PREPARE_TRLTOM_PACK, TRLTOM_PACK
  PUBLIC :: TRLTOM_UNPACK_HANDLE, PREPARE_TRLTOM_UNPACK, TRLTOM_UNPACK

  ! Handle produced by PREPARE_TRLTOM_PACK. It stores the allocator reservation
  ! from which TRLTOM_PACK later obtains the FOUBUF_IN buffer.
  TYPE TRLTOM_PACK_HANDLE
    TYPE(ALLOCATION_RESERVATION_HANDLE) :: HFOUBUF_IN
  END TYPE TRLTOM_PACK_HANDLE

  ! Handle produced by PREPARE_TRLTOM_UNPACK. It stores one reservation that
  ! TRLTOM_UNPACK carves into the four buffers ZINPS, ZINPA, ZINPS0 and ZINPA0.
  TYPE TRLTOM_UNPACK_HANDLE
    TYPE(ALLOCATION_RESERVATION_HANDLE) :: HINPS_AND_ZINPA
  END TYPE TRLTOM_UNPACK_HANDLE

  ! Padding multiple used for the JGL range in the CUTLASS branch of TRLTOM_UNPACK.
  ! This is a named constant: the module is PRIVATE by default, the name is not
  ! exported and nothing in this module assigns to it, so it must not be a
  ! mutable (implicitly SAVEd) module variable.
  INTEGER(KIND=JPIM), PARAMETER :: A = 8 !Alignment

CONTAINS
  FUNCTION PREPARE_TRLTOM_PACK(ALLOCATOR, KF_FS) RESULT(HTRLTOM_PACK)
    !**** *PREPARE_TRLTOM_PACK* - Reserve the buffer later filled by TRLTOM_PACK

    !     Purpose.
    !     --------
    !        Registers a reservation of 2*D%NLENGT0B*KF_FS elements of kind
    !        JPRBT with the buffered allocator and returns the handle.

    !     Explicit arguments :  ALLOCATOR    - allocator receiving the reservation
    !     --------------------  KF_FS        - number of fields
    !                           HTRLTOM_PACK - (result) handle holding the reservation

    !     ------------------------------------------------------------------

    USE PARKIND_ECTRANS,             ONLY: JPIM, JPRBT
    USE TPM_DISTR,                   ONLY: D
    USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_SIZE_T
    USE BUFFERED_ALLOCATOR_MOD,      ONLY: BUFFERED_ALLOCATOR, RESERVE

    IMPLICIT NONE

    TYPE(BUFFERED_ALLOCATOR), INTENT(INOUT) :: ALLOCATOR
    INTEGER(KIND=JPIM), INTENT(IN) :: KF_FS
    TYPE(TRLTOM_PACK_HANDLE) :: HTRLTOM_PACK

    ! Only used as the argument of SIZEOF; never assigned or read.
    REAL(KIND=JPRBT) :: DUMMY

    INTEGER(KIND=C_SIZE_T) :: INUM_ELEMS

    ! Widen every factor to C_SIZE_T before multiplying: the product
    ! D%NLENGT0B*KF_FS*2 evaluated in JPIM arithmetic can overflow for
    ! large problem sizes before the conversion to a byte count happens.
    INUM_ELEMS = 2_C_SIZE_T*INT(D%NLENGT0B,KIND=C_SIZE_T)*INT(KF_FS,KIND=C_SIZE_T)

    HTRLTOM_PACK%HFOUBUF_IN = RESERVE(ALLOCATOR, INUM_ELEMS*INT(SIZEOF(DUMMY),KIND=C_SIZE_T))
  END FUNCTION PREPARE_TRLTOM_PACK

  SUBROUTINE TRLTOM_PACK(ALLOCATOR,HTRLTOM_PACK,PREEL_COMPLEX,FOUBUF_IN,KF_FS)
    !**** *TRLTOM_PACK* - Copy fourier data from local array to buffer

    !     Purpose.
    !     --------
    !        Routine for copying fourier data from local array to buffer.
    !        Each copied value is scaled by 1/G_NLOEN(IGLG).

    !**   Interface.
    !     ----------
    !     CALL TRLTOM_PACK(...)

    !     Explicit arguments :  ALLOCATOR     - buffered allocator owning the reserved memory
    !     --------------------  HTRLTOM_PACK  - handle holding the reservation for FOUBUF_IN
    !                           PREEL_COMPLEX - local fourier/GP array (input), read as
    !                                           consecutive pairs at 2*JM+1 and 2*JM+2
    !                           FOUBUF_IN     - pointer associated here with the reserved
    !                                           region and filled with the scaled data
    !                           KF_FS         - number of fields
    !
    !     Externals.  ASSIGN_PTR, GET_ALLOCATION (BUFFERED_ALLOCATOR_MOD)
    !     ----------

    !     Author.
    !     -------
    !        Mats Hamrud *ECMWF*

    !     ------------------------------------------------------------------

    USE BUFFERED_ALLOCATOR_MOD,      ONLY: BUFFERED_ALLOCATOR, ASSIGN_PTR, GET_ALLOCATION
    USE PARKIND_ECTRANS,             ONLY: JPIM, JPRBT
    USE TPM_DISTR,                   ONLY: D, MYSETW, D_NSTAGTF, D_NPNTGTB0, D_NPTRLS, D_NDGL_FS
    USE TPM_GEOMETRY,                ONLY: G_NMEN, G_NLOEN
    USE TPM_DIM,                     ONLY: R_NSMAX
    USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_SIZE_T
    !

    IMPLICIT NONE

    REAL(KIND=JPRBT), INTENT(IN) :: PREEL_COMPLEX(:)
    REAL(KIND=JPRBT), POINTER, INTENT(OUT) :: FOUBUF_IN(:)
    INTEGER(KIND=JPIM),INTENT(IN) :: KF_FS
    TYPE(BUFFERED_ALLOCATOR), INTENT(IN) :: ALLOCATOR
    TYPE(TRLTOM_PACK_HANDLE), INTENT(IN) :: HTRLTOM_PACK

    INTEGER(KIND=JPIM) :: JM,JF,IGLG,ISTA,OFFSET_VAR,IOFF_LAT,KGL

    INTEGER(KIND=C_SIZE_T) :: IBUF_BYTES

    REAL(KIND=JPRBT)    :: SCAL

    ! Size in bytes of the region to map onto FOUBUF_IN. Every factor is widened
    ! to C_SIZE_T first so that the element count 2*D%NLENGT0B*KF_FS is never
    ! formed in JPIM arithmetic, where it can overflow for large problems.
    IBUF_BYTES = 2_C_SIZE_T*INT(D%NLENGT0B,KIND=C_SIZE_T)*INT(KF_FS,KIND=C_SIZE_T) &
        &      * INT(SIZEOF(FOUBUF_IN(1)),KIND=C_SIZE_T)

    CALL ASSIGN_PTR(FOUBUF_IN, GET_ALLOCATION(ALLOCATOR, HTRLTOM_PACK%HFOUBUF_IN),&
        & 1_C_SIZE_T, IBUF_BYTES)

#ifdef OMPGPU
#endif
#ifdef ACCGPU
    !$ACC DATA PRESENT(G_NMEN,D_NPNTGTB0,FOUBUF_IN,PREEL_COMPLEX,D_NSTAGTF,D_NDGL_FS,G_NLOEN, R_NSMAX) ASYNC(1)
#endif

    ! scale results and move into next transformation buffer

    ! Offset added to the local index KGL to obtain the index IGLG used
    ! for G_NMEN and G_NLOEN.
    OFFSET_VAR=D_NPTRLS(MYSETW)

#ifdef OMPGPU
#endif
#ifdef ACCGPU
    !$ACC PARALLEL LOOP PRIVATE(IGLG,IOFF_LAT,ISTA,SCAL) FIRSTPRIVATE(KF_FS,OFFSET_VAR) DEFAULT(NONE) &
    !$ACC& ASYNC(1) TILE(32,16,1)
#endif
    DO KGL=1,D_NDGL_FS
      ! JM runs over the full range 0..R_NSMAX so that the loop nest stays
      ! rectangular; the guard below skips every JM above G_NMEN(IGLG)
      ! (i.e. G_NMEN(IGLG) <= R_NSMAX is expected for all IGLG).
      DO JM=0,R_NSMAX
        DO JF=1,KF_FS
          IGLG = OFFSET_VAR+KGL-1
          IF (JM <= G_NMEN(IGLG)) THEN
            ! Start of field JF on latitude KGL inside PREEL_COMPLEX
            IOFF_LAT = KF_FS*D_NSTAGTF(KGL)+(JF-1)*(D_NSTAGTF(KGL+1)-D_NSTAGTF(KGL))

            SCAL = 1._JPRBT/REAL(G_NLOEN(IGLG),JPRBT)
            ! Start of the (JM,KGL) slot in FOUBUF_IN; each slot holds 2*KF_FS values
            ISTA  = D_NPNTGTB0(JM,KGL)*KF_FS*2

            FOUBUF_IN(ISTA+2*JF-1) = SCAL * PREEL_COMPLEX(IOFF_LAT+2*JM+1)
            FOUBUF_IN(ISTA+2*JF  ) = SCAL * PREEL_COMPLEX(IOFF_LAT+2*JM+2)
          ENDIF
        ENDDO
      ENDDO
    ENDDO
#ifdef OMPGPU
#endif
#ifdef ACCGPU
    !$ACC END DATA

    ! The kernel above was queued on async queue 1; block until it has finished.
    !$ACC WAIT(1)
#endif
  END SUBROUTINE TRLTOM_PACK

  FUNCTION PREPARE_TRLTOM_UNPACK(ALLOCATOR, KF_FS) RESULT(HTRLTOM_UNPACK)
    !**** *PREPARE_TRLTOM_UNPACK* - Reserve the buffers later filled by TRLTOM_UNPACK

    !     Purpose.
    !     --------
    !        Queries LEDIR_STRIDES for the buffer sizes and registers a single
    !        reservation large enough for two JPRBT buffers of IIN_SIZE elements
    !        and two JPRD buffers of IIN0_SIZE elements, each rounded up to a
    !        multiple of 128 bytes.

    !     Explicit arguments :  ALLOCATOR      - allocator receiving the reservation
    !     --------------------  KF_FS          - number of fields
    !                           HTRLTOM_UNPACK - (result) handle holding the reservation

    !     ------------------------------------------------------------------

    USE PARKIND_ECTRANS,        ONLY: JPIM, JPRBT, JPRD
    USE BUFFERED_ALLOCATOR_MOD, ONLY: BUFFERED_ALLOCATOR, RESERVE
    USE LEDIR_MOD,              ONLY: LEDIR_STRIDES
    USE ISO_C_BINDING,          ONLY: C_SIZE_T

    IMPLICIT NONE

    TYPE(BUFFERED_ALLOCATOR), INTENT(INOUT) :: ALLOCATOR
    INTEGER(KIND=JPIM), INTENT(IN) :: KF_FS
    TYPE(TRLTOM_UNPACK_HANDLE) :: HTRLTOM_UNPACK

    ! Strides are returned by LEDIR_STRIDES but not needed for the reservation.
    INTEGER(KIND=JPIM)  :: ISTRIDE_IN, ISTRIDE_IN0
    INTEGER(KIND=JPIM)  :: ILEN_IN, ILEN_IN0
    INTEGER(KIND=C_SIZE_T)  :: IBYTES_IN, IBYTES_IN0, ISIZE

    ! Only used as SIZEOF arguments; never assigned or read.
    REAL(KIND=JPRBT) :: ZDUM_BT
    REAL(KIND=JPRD) :: ZDUM_D

    CALL LEDIR_STRIDES(KF_FS,IIN_STRIDES0=ISTRIDE_IN,IIN_SIZE=ILEN_IN,&
                       IIN0_STRIDES0=ISTRIDE_IN0,IIN0_SIZE=ILEN_IN0)

    ! Padded byte size of one JPRBT buffer and of one JPRD buffer
    IBYTES_IN  = ALIGN(ILEN_IN*SIZEOF(ZDUM_BT),128)
    IBYTES_IN0 = ALIGN(ILEN_IN0*SIZEOF(ZDUM_D),128)

    ! Two buffers of each kind share the reservation
    ISIZE = 2_C_SIZE_T*(IBYTES_IN + IBYTES_IN0)

    HTRLTOM_UNPACK%HINPS_AND_ZINPA = RESERVE(ALLOCATOR, ISIZE)
  END FUNCTION PREPARE_TRLTOM_UNPACK

  SUBROUTINE TRLTOM_UNPACK(ALLOCATOR,HTRLTOM_UNPACK,FOUBUF,ZINPS,ZINPA,ZINPS0,ZINPA0,KF_FS,KF_UV)
    !**** *TRLTOM_UNPACK* - Copy fourier data from buffer into symmetric/antisymmetric arrays

    !     Purpose.
    !     --------
    !        Associates ZINPS, ZINPA, ZINPS0 and ZINPA0 with consecutive,
    !        128-byte padded parts of the reserved region, then fills them.
    !        For each local wavenumber index KMLOC, each JGL >= ISL and each
    !        JF in 1..2*KF_FS, the values of FOUBUF at JGL and at the mirrored
    !        index R_NDGL+1-JGL are combined into a sum (PAIS) and a difference
    !        (PAIA), scaled by F_RW(JGL) and stored. The first 4*KF_UV entries
    !        are additionally scaled by F_RACTHE(JGL).

    !     Explicit arguments :  ALLOCATOR      - buffered allocator owning the reserved memory
    !     --------------------  HTRLTOM_UNPACK - handle holding the reservation
    !                           FOUBUF         - input buffer
    !                           ZINPS, ZINPA   - sum / difference output for KM /= 0 (kind JPRBT)
    !                           ZINPS0, ZINPA0 - sum / difference output for KM == 0 (kind JPRD),
    !                                            odd JF only
    !                           KF_FS          - number of fields
    !                           KF_UV          - number of velocity fields (see the 4*KF_UV test)

    !     ------------------------------------------------------------------

    USE PARKIND_ECTRANS,             ONLY: JPIM, JPRBT, JPRD
    USE TPM_DIM,                     ONLY: R_NDGNH, R_NDGL
    USE TPM_GEOMETRY,                ONLY: G_NDGLU
    USE BUFFERED_ALLOCATOR_MOD,      ONLY: BUFFERED_ALLOCATOR, ASSIGN_PTR, GET_ALLOCATION
    USE TPM_FIELDS_FLAT,             ONLY: F_RW, F_RACTHE
    USE TPM_DISTR,                   ONLY: D_NUMP, D_MYMS, D_NPNTGTB1, D_OFFSETS_GEMM1
    USE LEDIR_MOD,                   ONLY: LEDIR_STRIDES
    USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_SIZE_T

    IMPLICIT NONE

    REAL(KIND=JPRBT), INTENT(IN) :: FOUBUF(:)
    REAL(KIND=JPRBT), POINTER, INTENT(INOUT) :: ZINPS(:), ZINPA(:)
    REAL(KIND=JPRD), POINTER, INTENT(INOUT) :: ZINPS0(:), ZINPA0(:)
    INTEGER(KIND=JPIM), INTENT(IN) :: KF_FS, KF_UV
    TYPE(BUFFERED_ALLOCATOR), INTENT(IN) :: ALLOCATOR
    TYPE(TRLTOM_UNPACK_HANDLE), INTENT(IN) :: HTRLTOM_UNPACK

    INTEGER(KIND=JPIM) :: IIN_STRIDES0, IIN_SIZE
    INTEGER(KIND=JPIM) :: IIN0_STRIDES0, IIN0_SIZE

    INTEGER(KIND=C_SIZE_T) :: IALLOC_POS, IALLOC_SZ

    INTEGER(KIND=8)  :: JF
    INTEGER(KIND=JPIM) :: KM, ISL, IGLS, OFFSET1, OFFSET2, JGL, KMLOC

    REAL(KIND=JPRBT) :: PAIA, PAIS

    CALL LEDIR_STRIDES(KF_FS,IIN_STRIDES0=IIN_STRIDES0,IIN_SIZE=IIN_SIZE,&
                       IIN0_STRIDES0=IIN0_STRIDES0,IIN0_SIZE=IIN0_SIZE)

    ! Carve the single reservation into four consecutive sub-buffers. The order
    ! and the padded sizes must match the total computed in PREPARE_TRLTOM_UNPACK.
    IALLOC_POS=1

    IALLOC_SZ = ALIGN(IIN_SIZE*SIZEOF(ZINPS(0)),128)
    CALL ASSIGN_PTR(ZINPS, GET_ALLOCATION(ALLOCATOR, HTRLTOM_UNPACK%HINPS_AND_ZINPA),&
        & IALLOC_POS, IALLOC_SZ)
    IALLOC_POS=IALLOC_POS+IALLOC_SZ

    IALLOC_SZ = ALIGN(IIN_SIZE*SIZEOF(ZINPA(0)),128)
    CALL ASSIGN_PTR(ZINPA, GET_ALLOCATION(ALLOCATOR, HTRLTOM_UNPACK%HINPS_AND_ZINPA),&
        & IALLOC_POS, IALLOC_SZ)
    IALLOC_POS=IALLOC_POS+IALLOC_SZ

    IALLOC_SZ = ALIGN(IIN0_SIZE*SIZEOF(ZINPS0(0)),128)
    CALL ASSIGN_PTR(ZINPS0, GET_ALLOCATION(ALLOCATOR, HTRLTOM_UNPACK%HINPS_AND_ZINPA),&
        & IALLOC_POS, IALLOC_SZ)
    IALLOC_POS=IALLOC_POS+IALLOC_SZ

    IALLOC_SZ = ALIGN(IIN0_SIZE*SIZEOF(ZINPA0(0)),128)
    CALL ASSIGN_PTR(ZINPA0, GET_ALLOCATION(ALLOCATOR, HTRLTOM_UNPACK%HINPS_AND_ZINPA),&
        & IALLOC_POS, IALLOC_SZ)

#ifdef OMPGPU
#endif
#ifdef ACCGPU
    !$ACC DATA &
    !$ACC& PRESENT(ZINPS,ZINPA,ZINPS0,ZINPA0) &
    !$ACC& PRESENT(F_RW,F_RACTHE) &
    !$ACC& PRESENT(D_MYMS,D_NUMP,R_NDGNH,R_NDGL,G_NDGLU) &
    !$ACC& PRESENT(D_NPNTGTB1,D_OFFSETS_GEMM1,FOUBUF)

    !$ACC PARALLEL LOOP DEFAULT(NONE) COLLAPSE(3) PRIVATE(KM,ISL,IGLS,OFFSET1,OFFSET2,JGL,PAIA,PAIS) &
    !$ACC&              FIRSTPRIVATE(KF_FS,KF_UV,IIN_STRIDES0,IIN0_STRIDES0) ASYNC(1)
#endif
    DO KMLOC=1,D_NUMP
      DO JGL=1,R_NDGNH
        DO JF=1,KF_FS*2
          KM = D_MYMS(KMLOC)
          ! First JGL that is processed for this wavenumber
          ISL = R_NDGNH-G_NDGLU(KM)+1
          IF (JGL >= ISL) THEN
            !(DO JGL=ISL,R_NDGNH)
            ! Mirrored index paired with JGL
            IGLS = R_NDGL+1-JGL
            OFFSET1 = D_NPNTGTB1(KMLOC,JGL )*2*KF_FS
            OFFSET2 = D_NPNTGTB1(KMLOC,IGLS)*2*KF_FS
            PAIA = FOUBUF(OFFSET1+JF)-FOUBUF(OFFSET2+JF)
            PAIS = FOUBUF(OFFSET1+JF)+FOUBUF(OFFSET2+JF)
            IF (JF <= 4*KF_UV) THEN
                ! Multiply in case of velocity
              PAIA = PAIA*REAL(F_RACTHE(JGL),JPRBT)
              PAIS = PAIS*REAL(F_RACTHE(JGL),JPRBT)
            ENDIF
            IF (KM /= 0) THEN
              ZINPA(JF+(JGL-ISL)*IIN_STRIDES0+IIN_STRIDES0*D_OFFSETS_GEMM1(KMLOC))=PAIA*REAL(F_RW(JGL),JPRBT)
              ZINPS(JF+(JGL-ISL)*IIN_STRIDES0+IIN_STRIDES0*D_OFFSETS_GEMM1(KMLOC))=PAIS*REAL(F_RW(JGL),JPRBT)
            ELSEIF (MOD(JF-1,2) == 0) THEN
              ! every other field is sufficient because Im(KM=0) == 0
              ZINPA0((JF-1)/2+1+(JGL-1)*IIN0_STRIDES0)=PAIA*REAL(F_RW(JGL),JPRBT)
              ZINPS0((JF-1)/2+1+(JGL-1)*IIN0_STRIDES0)=PAIS*REAL(F_RW(JGL),JPRBT)
            ENDIF
          ENDIF
        ENDDO
      ENDDO
    END DO

#ifdef OMPGPU
#endif

#if defined(USE_CUTLASS) && defined(USE_CUTLASS_3XTF32)
    ! Zero the padding entries between G_NDGLU(KM) and the next multiple of A
    ! in ZINPA and ZINPS for every KM /= 0.
#ifdef ACCGPU
    !$ACC PARALLEL LOOP DEFAULT(NONE) COLLAPSE(2) PRIVATE(KM,JGL) &
    !$ACC&              FIRSTPRIVATE(KF_FS,IIN_STRIDES0) ASYNC(1)
#endif
    DO KMLOC=1,D_NUMP
      DO JF=1,KF_FS*2
        KM = D_MYMS(KMLOC)
        !$ACC LOOP SEQ
        DO JGL=G_NDGLU(KM),ALIGN(G_NDGLU(KM),A)-1
          IF (KM /= 0) THEN
            ! The literal kind must be JPRBT: it is the kind of ZINPA/ZINPS and
            ! the only working-precision real kind imported in this routine.
            ZINPA(JF+JGL*IIN_STRIDES0+IIN_STRIDES0*D_OFFSETS_GEMM1(KMLOC))=0.0_JPRBT
            ZINPS(JF+JGL*IIN_STRIDES0+IIN_STRIDES0*D_OFFSETS_GEMM1(KMLOC))=0.0_JPRBT
          ENDIF
        ENDDO
      ENDDO
    END DO
#endif

#ifdef OMPGPU
#endif
#ifdef ACCGPU
    !$ACC END DATA
#endif

  END SUBROUTINE TRLTOM_UNPACK

END MODULE TRLTOM_PACK_UNPACK

